import torch
import copy
import torch.distributed as dist
import sys
sys.path.append('extra_utils')
from extra_utils.distributed_utils import init_distributed_mode, cleanup, is_main_process


def get_matrix(class_mean_list, conf, rank):
    """Build this rank's class-mean / epoch matrix and gather the clients'
    matrices onto the main process (rank 0) via point-to-point communication.

    Assumes ``dist.init_process_group`` has already been called and that the
    world size is exactly 3 (rank 0 = server, ranks 1 and 2 = clients) —
    see the hard-coded ``src``/``dst`` ranks below.

    Args:
        class_mean_list: list of tensors, one per global epoch, concatenated
            along dim=1 (presumably each is a (num_classes, 1) column of
            per-class means — TODO confirm against the caller).
        conf: configuration dict; only ``conf["global_epochs"]`` is read here.
        rank: this process's rank in the default process group.

    Returns:
        Tuple ``(own_matrix, matrix_1, matrix_2)``. On rank 0, ``matrix_1``
        and ``matrix_2`` hold the matrices received from ranks 1 and 2; on
        the other ranks they are just copies of this rank's own matrix.
    """
    # One column per epoch -> (num_classes, global_epochs) matrix.
    class_epoch_matrix = torch.cat(
        [class_mean_list[i] for i in range(conf["global_epochs"])], dim=1
    )
    # clone() is the idiomatic (and cheaper) way to duplicate a tensor here;
    # copy.deepcopy is unnecessary for a plain communication buffer that is
    # either overwritten by recv (rank 0) or sent unchanged (other ranks).
    class_epoch_matrix_1 = class_epoch_matrix.clone()
    class_epoch_matrix_2 = class_epoch_matrix.clone()
    if is_main_process():
        # Blocking receives: rank 0 collects each client's matrix in place.
        dist.recv(class_epoch_matrix_1, src=1)
        dist.recv(class_epoch_matrix_2, src=2)
    elif rank == 1:
        dist.send(class_epoch_matrix_1, dst=0)
    else:
        # NOTE(review): every rank other than 0/1 falls through to this send —
        # only correct when the world size is exactly 3.
        dist.send(class_epoch_matrix_2, dst=0)
    dist.barrier()  # Wait for all ranks to finish communicating before continuing.
    if is_main_process():
        print(class_epoch_matrix)
        print(class_epoch_matrix_1)
        print(class_epoch_matrix_2)
    return class_epoch_matrix, class_epoch_matrix_1, class_epoch_matrix_2
